{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OlDwW4HY8MoU"
   },
   "source": [
    "# Model understanding and interpretability"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2pFlZCUv7hM-"
   },
   "source": [
    "In this colab, we will:\n",
    "- Learn how to interpret model results and reason about the features\n",
    "- Visualize the model results\n",
    "\n",
    "Please complete the exercises and answer the questions tagged **`???`**.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rYhQ7yT_q0QR"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# We will use some np and pandas for dealing with input data.\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "# And of course, we need tensorflow.\n",
    "import tensorflow as tf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CkDZWpif2Pko"
   },
   "source": [
    "Below we demonstrate both *local* and *global* model interpretability for gradient boosted trees. \n",
    "\n",
    "Local interpretability refers to an understanding of a model’s predictions at the individual example level, while global interpretability refers to an understanding of the model as a whole.\n",
    "\n",
    "For local interpretability, we show how to create and visualize per-instance contributions using the technique outlined in [Palczewska et al](https://arxiv.org/pdf/1312.1121.pdf) and by Saabas in [Interpreting Random Forests](http://blog.datadive.net/interpreting-random-forests/) (this method is also available in scikit-learn for Random Forests in the [`treeinterpreter`](https://github.com/andosa/treeinterpreter) package). To distinguish this from feature importances, we refer to these values as directional feature contributions (DFCs).\n",
    "\n",
    "For global interpretability we show how to retrieve and visualize gain-based feature importances, [permutation feature importances](https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf) and also show aggregated DFCs."
   ]
  },
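  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a DFC concrete, here is a minimal sketch on a hand-built toy tree. The tree, its node values and the `toy_dfcs` helper are illustrative assumptions and not part of TensorFlow. Each node stores the mean prediction of the examples that reach it, and every split along an example's path credits the change in that value to the feature being split on, so the bias plus the sum of the contributions equals the prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy tree (not a TensorFlow structure). Each node stores the\n",
    "# mean prediction of the training examples that reach it.\n",
    "TOY_TREE = {\n",
    "    'value': 0.4, 'feature': 'sex_is_female', 'threshold': 0.5,\n",
    "    'left': {'value': 0.2, 'feature': 'age', 'threshold': 10.0,\n",
    "             'left': {'value': 0.6},\n",
    "             'right': {'value': 0.15}},\n",
    "    'right': {'value': 0.75},\n",
    "}\n",
    "\n",
    "def toy_dfcs(tree, example):\n",
    "    \"\"\"Returns (bias, contributions) for one example on one toy tree.\"\"\"\n",
    "    bias = tree['value']  # Root value: the prediction before using any feature.\n",
    "    contributions = {}\n",
    "    node = tree\n",
    "    while 'feature' in node:  # Leaves have no 'feature' key.\n",
    "        name = node['feature']\n",
    "        child = node['left'] if example[name] < node['threshold'] else node['right']\n",
    "        # Credit the change in node value to the feature that was split on.\n",
    "        contributions[name] = contributions.get(name, 0.0) + child['value'] - node['value']\n",
    "        node = child\n",
    "    return bias, contributions\n",
    "\n",
    "toy_bias, toy_contribs = toy_dfcs(TOY_TREE, {'sex_is_female': 0.0, 'age': 8.0})\n",
    "toy_prediction = toy_bias + sum(toy_contribs.values())\n",
    "print('bias: {}, contributions: {}, prediction: {}'.format(\n",
    "    toy_bias, toy_contribs, toy_prediction))"
   ]
  },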
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vfxkZE-MaY0h"
   },
   "source": [
    "# Setup\n",
    "## Load dataset\n",
    "We will be using the Titanic dataset, where the goal is to predict passenger survival given characteristics such as gender, age, class, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gd995mWZzOTz"
   },
   "outputs": [],
   "source": [
    "tf.logging.set_verbosity(tf.logging.ERROR)\n",
    "tf.set_random_seed(123)\n",
    "\n",
    "# Load dataset.\n",
    "dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')\n",
    "dfeval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')\n",
    "y_train = dftrain.pop('survived')\n",
    "y_eval = dfeval.pop('survived')\n",
    "\n",
    "# Feature columns.\n",
    "fcol = tf.feature_column\n",
    "CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',\n",
    "                       'embark_town', 'alone']\n",
    "NUMERIC_COLUMNS = ['age', 'fare']\n",
    "\n",
    "def one_hot_cat_column(feature_name, vocab):\n",
    "  return fcol.indicator_column(\n",
    "      fcol.categorical_column_with_vocabulary_list(feature_name,\n",
    "                                                 vocab))\n",
    "fc = []\n",
    "for feature_name in CATEGORICAL_COLUMNS:\n",
    "  # Need to one-hot encode categorical features.\n",
    "  vocabulary = dftrain[feature_name].unique()\n",
    "  fc.append(one_hot_cat_column(feature_name, vocabulary))\n",
    "\n",
    "for feature_name in NUMERIC_COLUMNS:\n",
    "  fc.append(fcol.numeric_column(feature_name,\n",
    "                                dtype=tf.float32))\n",
    "\n",
    "# Input functions.\n",
    "def make_input_fn(X, y, n_epochs=None):\n",
    "  def input_fn():\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((X.to_dict(orient='list'), y))\n",
    "    # For training, cycle thru dataset as many times as need (n_epochs=None).\n",
    "    dataset = (dataset\n",
    "      .repeat(n_epochs)\n",
    "      .batch(len(y)))  # Use entire dataset since this is such a small dataset.\n",
    "    return dataset\n",
    "  return input_fn\n",
    "\n",
    "# Training and evaluation input functions.\n",
    "train_input_fn = make_input_fn(dftrain, y_train)\n",
    "eval_input_fn = make_input_fn(dfeval, y_eval, n_epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5bgBDHLJsGnb"
   },
   "source": [
    "# Interpret model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CnlBxy2X2Txj"
   },
   "source": [
    "## Train model\n",
    "Train a gradient boosted trees classifier on the Titanic data. Setting `center_bias=True` is required to compute the directional feature contributions (DFCs) used for local interpretability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 386
    },
    "colab_type": "code",
    "id": "v9vxxoCD9uxR",
    "outputId": "dd4384da-3ec5-446f-e055-e46f2ab20c60"
   },
   "outputs": [],
   "source": [
    "params = {\n",
    "  'n_trees': 50,\n",
    "  'max_depth': 3,\n",
    "  'n_batches_per_layer': 1,\n",
    "  # You must enable center_bias = True to get DFCs. This will force the model to\n",
    "  # make an initial prediction before using any features (e.g. use the mean of\n",
    "  # the training labels for regression or log odds for classification when\n",
    "  # using cross entropy loss).\n",
    "  'center_bias': True\n",
    "}\n",
    "\n",
    "est = tf.estimator.BoostedTreesClassifier(fc, **params)\n",
    "# Train model.\n",
    "est.train(train_input_fn)\n",
    "\n",
    "# Evaluation.\n",
    "results = est.evaluate(eval_input_fn)\n",
    "clear_output()\n",
    "pd.Series(results).to_frame()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8R-hJcShzTtL"
   },
   "source": [
    "## Local interpretability\n",
    "Next you will output the directional feature contributions (DFCs) to explain individual predictions using the approach outlined in [Palczewska et al](https://arxiv.org/pdf/1312.1121.pdf) and by Saabas in [Interpreting Random Forests](http://blog.datadive.net/interpreting-random-forests/) (this method is also available in scikit-learn for Random Forests in the [`treeinterpreter`](https://github.com/andosa/treeinterpreter) package). The DFCs are generated with:\n",
    "\n",
    "`pred_dicts = list(est.experimental_predict_with_explanations(pred_input_fn))`\n",
    "\n",
    "(Note: The method is named experimental as we may modify the API before dropping the experimental prefix.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ji0U4FsNzROH"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns_colors = sns.color_palette('colorblind')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JTWBVzmvtctG"
   },
   "outputs": [],
   "source": [
    "pred_dicts = list(est.experimental_predict_with_explanations(eval_input_fn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bu5HuGRE61B2"
   },
   "outputs": [],
   "source": [
    "def clean_feature_names(df):\n",
    "    \"\"\"Boilerplate code to clean up feature names -- this is unneeded in TF 2.0\"\"\"\n",
    "    df.columns = [v.split(':')[0].split('_indi')[0] for v in df.columns.tolist()]\n",
    "    df = df.T.groupby(level=0).sum().T\n",
    "    return df\n",
    "\n",
    "# Create DFC Pandas dataframe.\n",
    "labels = y_eval.values\n",
    "probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])\n",
    "df_dfc = pd.DataFrame([pred['dfc'] for pred in pred_dicts])\n",
    "df_dfc.columns = est._names_for_feature_id\n",
    "df_dfc = clean_feature_names(df_dfc)\n",
    "df_dfc.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "x7M8qssg65jd"
   },
   "outputs": [],
   "source": [
    "# Sum of DFCs + bias == probability.\n",
    "bias = pred_dicts[0]['bias']\n",
    "dfc_prob = df_dfc.sum(axis=1) + bias\n",
    "np.testing.assert_almost_equal(dfc_prob.values,\n",
    "                               probs.values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0X9h0iq6_7uN"
   },
   "source": [
    "Plot results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Exercise: Plot figures for multiple examples. How would you explain each plot in plain English?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "93d4Ccx368GH"
   },
   "outputs": [],
   "source": [
    "import seaborn as sns  # Make plotting nicer.\n",
    "sns_colors = sns.color_palette('colorblind')\n",
    "\n",
    "def plot_dfcs(example_id):\n",
    "    # Use the function argument rather than the global ID.\n",
    "    label, prob = labels[example_id], probs[example_id]\n",
    "    example = df_dfc.iloc[example_id]  # Choose ith example from evaluation set.\n",
    "    TOP_N = 8  # View top 8 features.\n",
    "    sorted_ix = example.abs().sort_values()[-TOP_N:].index\n",
    "    ax = example[sorted_ix].plot(kind='barh', color='g', figsize=(10,5))\n",
    "    ax.grid(False, axis='y')\n",
    "\n",
    "    plt.title('Feature contributions for example {}\\n pred: {:1.2f}; label: {}'.format(example_id, prob, label))\n",
    "    plt.xlabel('Contribution to predicted probability')\n",
    "\n",
    "ID = 102  # Change this.\n",
    "plot_dfcs(ID)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tx5p4vEhuczg"
   },
   "source": [
    "### Prettier plotting\n",
    "Color-code the bars by direction and add the feature values to the figure. Please do not worry about the details of the plotting code :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6z_Tq1Pquczj"
   },
   "outputs": [],
   "source": [
    "def plot_example_pretty(example):\n",
    "    \"\"\"Boilerplate code for better plotting :)\"\"\"\n",
    "    \n",
    "    def _get_color(value):\n",
    "        \"\"\"To make positive DFCs plot green, negative DFCs plot red.\"\"\"\n",
    "        green, red = sns.color_palette()[2:4]\n",
    "        if value >= 0: return green\n",
    "        return red\n",
    "\n",
    "\n",
    "    def _add_feature_values(feature_values, ax):\n",
    "        \"\"\"Display feature's values on left of plot.\"\"\"\n",
    "        x_coord = ax.get_xlim()[0]\n",
    "        OFFSET = 0.15\n",
    "        for y_coord, (feat_name, feat_val) in enumerate(feature_values.items()):\n",
    "            t = plt.text(x_coord, y_coord - OFFSET, '{}'.format(feat_val), size=12)\n",
    "            t.set_bbox(dict(facecolor='white', alpha=0.5))\n",
    "        from matplotlib.font_manager import FontProperties\n",
    "        font = FontProperties()\n",
    "        font.set_weight('bold')\n",
    "        t = plt.text(x_coord, y_coord + 1 - OFFSET, 'feature\\nvalue',\n",
    "        fontproperties=font, size=12)\n",
    "\n",
    "\n",
    "    TOP_N = 8 # View top 8 features.\n",
    "    sorted_ix = example.abs().sort_values()[-TOP_N:].index  # Sort by magnitude.\n",
    "    example = example[sorted_ix]\n",
    "    colors = example.map(_get_color).tolist()\n",
    "    ax = example.to_frame().plot(kind='barh',\n",
    "                          color=[colors],\n",
    "                          legend=None,\n",
    "                          alpha=0.75,\n",
    "                          figsize=(10,6))\n",
    "    ax.grid(False, axis='y')\n",
    "    ax.set_yticklabels(ax.get_yticklabels(), size=14)\n",
    "    _add_feature_values(dfeval.iloc[ID].loc[sorted_ix], ax)\n",
    "    ax.set_title('Feature contributions for example {}\\n pred: {:1.2f}; label: {}'.format(ID, probs[ID], labels[ID]))\n",
    "    ax.set_xlabel('Contribution to predicted probability', size=14)\n",
    "    plt.show()\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ht1P2-1euczk"
   },
   "outputs": [],
   "source": [
    "# Plot results.\n",
    "ID = 102\n",
    "example = df_dfc.iloc[ID]  # Choose ith example from evaluation set.\n",
    "ax = plot_example_pretty(example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CvF6vrEA7oQm"
   },
   "source": [
    "## Global feature importances\n",
    "\n",
    "1. Gain-based feature importances using `est.experimental_feature_importances`\n",
    "2. Aggregate DFCs using `est.experimental_predict_with_explanations`\n",
    "3. Permutation importances\n",
    "\n",
    "Gain-based feature importances measure the loss change when splitting on a particular feature, while permutation feature importances are computed by shuffling each feature of the evaluation set one at a time, re-evaluating the model, and attributing the change in model performance to the shuffled feature.\n",
    "\n",
    "In general, permutation feature importances are preferred to gain-based feature importances, though both methods can be unreliable when potential predictor variables vary in their scale of measurement or their number of categories, and when features are correlated ([source](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-9-307)). Check out [this article](http://explained.ai/rf-importance/index.html) for an in-depth overview and discussion of the different feature importance types.\n",
    "\n",
    "### 1. Gain-based feature importances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QuN7GU0J7MYr"
   },
   "outputs": [],
   "source": [
    "features, importances = est.experimental_feature_importances(normalize=True)\n",
    "df_imp = pd.DataFrame(importances, columns=['importances'], index=features)\n",
    "# For plotting purposes. This is not needed in TF 2.0.\n",
    "df_imp = clean_feature_names(df_imp.T).T.sort_values('importances', ascending=False)\n",
    "\n",
    "# Visualize importances.\n",
    "N = 8\n",
    "ax = df_imp.iloc[0:N][::-1]\\\n",
    "    .plot(kind='barh',\n",
    "          color=sns_colors[0],\n",
    "          title='Gain feature importances',\n",
    "          figsize=(10, 6))\n",
    "ax.grid(False, axis='y')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3RIx6JliBB9Z"
   },
   "source": [
    "**???** What does the x axis represent?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Be0-JLD1Bp2k"
   },
   "source": [
    "**???** Can we completely trust these results and the magnitudes?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mixDJm2P_rSW"
   },
   "source": [
    "### 2. Average absolute DFCs\n",
    "We can also average the absolute values of DFCs to understand impact at a global level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WzfgBWvJ_q3F"
   },
   "outputs": [],
   "source": [
    "# Plot.\n",
    "dfc_mean = df_dfc.abs().mean()\n",
    "sorted_ix = dfc_mean.abs().sort_values()[-8:].index  # Average and sort by absolute.\n",
    "ax = dfc_mean[sorted_ix].plot(kind='barh',\n",
    "                       color=sns_colors[1],\n",
    "                       title='Mean |directional feature contributions|',\n",
    "                       figsize=(10, 6))\n",
    "ax.grid(False, axis='y')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3ZIwZnJHAGm-"
   },
   "source": [
    "We can also see how DFCs vary as a feature value varies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tX8c1-k4AIZU"
   },
   "outputs": [],
   "source": [
    "age = pd.Series(df_dfc.age.values, index=dfeval.age.values).sort_index()\n",
    "sns.jointplot(age.index.values, age.values);"
   ]
  },
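  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Permutation feature importances\n",
    "The shuffling procedure described earlier can be sketched as follows. This is a minimal sketch and not a TensorFlow API: `permutation_importances`, `predict_fn`, `metric_fn` and the toy data are hypothetical names introduced for illustration. Each column is shuffled in turn and the drop in the metric relative to the unshuffled baseline is attributed to that column. To apply it to the estimator trained above, `predict_fn` would presumably wrap `est.predict` with an input function built from the shuffled DataFrame; that wiring is left as an assumption here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def permutation_importances(predict_fn, X, y, metric_fn, seed=0):\n",
    "    \"\"\"Drop in metric_fn when each column of X is shuffled, one at a time.\n",
    "\n",
    "    Illustrative helper (an assumption, not a library function): predict_fn\n",
    "    maps a DataFrame to predictions and metric_fn(y, preds) returns a score\n",
    "    where higher is better.\n",
    "    \"\"\"\n",
    "    rng = np.random.RandomState(seed)\n",
    "    baseline = metric_fn(y, predict_fn(X))\n",
    "    importances = {}\n",
    "    for col in X.columns:\n",
    "        X_shuffled = X.copy()\n",
    "        # Shuffle one column to break its relationship with the label.\n",
    "        X_shuffled[col] = rng.permutation(X[col].values)\n",
    "        importances[col] = baseline - metric_fn(y, predict_fn(X_shuffled))\n",
    "    return pd.Series(importances).sort_values(ascending=False)\n",
    "\n",
    "# Toy check: the label depends only on column 'a', so shuffling 'b' changes nothing.\n",
    "toy_rng = np.random.RandomState(1)\n",
    "toy_X = pd.DataFrame({'a': toy_rng.uniform(size=200),\n",
    "                      'b': toy_rng.uniform(size=200)})\n",
    "toy_y = (toy_X['a'] > 0.5).astype(int).values\n",
    "toy_predict = lambda df: (df['a'] > 0.5).astype(int).values\n",
    "accuracy = lambda y_true, y_pred: np.mean(y_true == y_pred)\n",
    "toy_importances = permutation_importances(toy_predict, toy_X, toy_y, accuracy)\n",
    "print(toy_importances)"
   ]
  },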
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ZqhB3WlDUuIL"
   },
   "source": [
    "# Visualizing the model's prediction surface"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Jd9hQrAUzjYb"
   },
   "source": [
    "Let's first simulate/create training data using the following formula:\n",
    "\n",
    "\n",
    "$z = x e^{-x^2 - y^2}$\n",
    "\n",
    "\n",
    "where $z$ is the dependent variable we are trying to predict and $x$ and $y$ are the features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "g54QgI2MB1-C"
   },
   "outputs": [],
   "source": [
    "from numpy.random import uniform, seed\n",
    "from matplotlib.mlab import griddata\n",
    "\n",
    "# Create fake data\n",
    "seed(0)\n",
    "npts = 5000\n",
    "x = uniform(-2, 2, npts)\n",
    "y = uniform(-2, 2, npts)\n",
    "z = x*np.exp(-x**2 - y**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SSE5NGH0D12J"
   },
   "outputs": [],
   "source": [
    "# Prep data for training.\n",
    "df = pd.DataFrame({'x': x, 'y': y, 'z': z})\n",
    "\n",
    "# Build the prediction grid (no trailing commas, which would create tuples).\n",
    "xi = np.linspace(-2.0, 2.0, 200)\n",
    "yi = np.linspace(-2.1, 2.1, 210)\n",
    "xi, yi = np.meshgrid(xi, yi)\n",
    "\n",
    "df_predict = pd.DataFrame({\n",
    "    'x' : xi.flatten(),\n",
    "    'y' : yi.flatten(),\n",
    "})\n",
    "predict_shape = xi.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eH-GHdGvD39j"
   },
   "outputs": [],
   "source": [
    "def plot_contour(x, y, z, **kwargs):\n",
    "  \"\"\"Draws contour lines and filled contours of z over the (x, y) grid.\"\"\"\n",
    "  plt.figure(figsize=(10, 8))\n",
    "  # Contour the gridded data. The color scale is taken from the global `zi`\n",
    "  # (the gridded true function), so it is the same for every plot.\n",
    "  CS = plt.contour(x, y, z, 15, linewidths=0.5, colors='k')\n",
    "  CS = plt.contourf(x, y, z, 15,\n",
    "                    vmax=abs(zi).max(), vmin=-abs(zi).max(), cmap='RdBu_r')\n",
    "  plt.colorbar()  # Draw colorbar.\n",
    "  # Restrict the view to the range of the training data.\n",
    "  plt.xlim(-2, 2)\n",
    "  plt.ylim(-2, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BSTgrzVJ0vGn"
   },
   "source": [
    "We can visualize our function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KQUOxy5e0uel"
   },
   "outputs": [],
   "source": [
    "zi = griddata(x, y, z, xi, yi, interp='linear')\n",
    "plot_contour(xi, yi, zi)\n",
    "plt.scatter(df.x, df.y, marker='.')\n",
    "plt.title('Contour on training data')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "T6i9EuURCOyD"
   },
   "outputs": [],
   "source": [
    "def predict(est):\n",
    "  \"\"\"Predictions from a given estimator.\"\"\"\n",
    "  predict_input_fn = lambda: tf.data.Dataset.from_tensors(dict(df_predict))\n",
    "  preds = np.array([p['predictions'][0] for p in est.predict(predict_input_fn)])\n",
    "  return preds.reshape(predict_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KPTN9zIxELAJ"
   },
   "source": [
    "First let's try to fit a linear model to the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "z6OH5EbsGeL9"
   },
   "outputs": [],
   "source": [
    "fc = [tf.feature_column.numeric_column('x'),\n",
    "      tf.feature_column.numeric_column('y')]\n",
    "\n",
    "train_input_fn = make_input_fn(df, df.z)\n",
    "est = tf.estimator.LinearRegressor(fc)\n",
    "est.train(train_input_fn, max_steps=500);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4x5tKvCtEPxy"
   },
   "outputs": [],
   "source": [
    "plot_contour(xi, yi, predict(est))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rkjrUIkz1kEL"
   },
   "source": [
    "Not very good at all..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "37SuxcCq2NbU"
   },
   "source": [
    "**???** Why is the linear model not performing well for this problem? Can you think of how to improve it just using a linear model?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DVX-kuee1GZ0"
   },
   "source": [
    "Next, let's fit a gradient boosted decision tree (GBDT) model to the data and try to understand what the model does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_IVuaW_9CPGM",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for n_trees in [1,2,3,10,30,50,100,200]:\n",
    "  est = tf.estimator.BoostedTreesRegressor(fc,\n",
    "                                          n_batches_per_layer=1,\n",
    "                                          max_depth=4,\n",
    "                                          n_trees=n_trees)\n",
    "  est.train(train_input_fn)\n",
    "  plot_contour(xi, yi, predict(est))\n",
    "  plt.text(-1.8, 2.1, '# trees: {}'.format(n_trees), color='w', backgroundcolor='black', size=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HUqeJGsmGywQ"
   },
   "source": [
    "A lot better. The model is learning to approximate the true function surface, but we can do better."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jw0HPcM216ak"
   },
   "source": [
    "**??? Exercise** Can you get a better fit?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TV9HjfriHIXw"
   },
   "source": [
    "**??? BONUS**: If you have time, try making your own functions to learn and visualize them above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2019 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "ASL_c_boosted_trees_model_understanding",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
